/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.CompositeTypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.types.NullKeyFieldException;
import org.apache.flink.util.InstantiationUtil;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Field;
import java.util.List;

@Internal
public final class PojoComparator<T> extends CompositeTypeComparator<T>
        implements java.io.Serializable {

    private static final long serialVersionUID = 1L;

    // Reflection fields for the comp fields
    private transient Field[] keyFields;

    private final TypeComparator<Object>[] comparators;

    private final int[] normalizedKeyLengths;

    private final int numLeadingNormalizableKeys;

    private final int normalizableKeyPrefixLen;

    private final boolean invertNormKey;

    private TypeSerializer<T> serializer;

    private final Class<T> type;

    @SuppressWarnings("unchecked")
    public PojoComparator(
            Field[] keyFields,
            TypeComparator<?>[] comparators,
            TypeSerializer<T> serializer,
            Class<T> type) {
        this.keyFields = keyFields;
        this.comparators = (TypeComparator<Object>[]) comparators;

        this.type = type;
        this.serializer = serializer;

        // set up auxiliary fields for normalized key support
        this.normalizedKeyLengths = new int[keyFields.length];
        int nKeys = 0;
        int nKeyLen = 0;
        boolean inverted = false;

        for (Field keyField : keyFields) {
            keyField.setAccessible(true);
        }

        for (int i = 0; i < this.comparators.length; i++) {
            TypeComparator<?> k = this.comparators[i];
            if (k == null) {
                throw new IllegalArgumentException("One of the passed comparators is null");
            }
            if (keyFields[i] == null) {
                throw new IllegalArgumentException("One of the passed reflection fields is null");
            }

            // as long as the leading keys support normalized keys, we can build up the composite
            // key
            if (k.supportsNormalizedKey()) {
                if (i == 0) {
                    // the first comparator decides whether we need to invert the key direction
                    inverted = k.invertNormalizedKey();
                } else if (k.invertNormalizedKey() != inverted) {
                    // if a successor does not agree on the inversion direction, it cannot be part
                    // of the normalized key
                    break;
                }

                nKeys++;
                final int len = k.getNormalizeKeyLen();
                if (len < 0) {
                    throw new RuntimeException(
                            "Comparator "
                                    + k.getClass().getName()
                                    + " specifies an invalid length for the normalized key: "
                                    + len);
                }
                this.normalizedKeyLengths[i] = len;
                nKeyLen += this.normalizedKeyLengths[i];

                if (nKeyLen < 0) {
                    // overflow, which means we are out of budget for normalized key space anyways
                    nKeyLen = Integer.MAX_VALUE;
                    break;
                }
            } else {
                break;
            }
        }
        this.numLeadingNormalizableKeys = nKeys;
        this.normalizableKeyPrefixLen = nKeyLen;
        this.invertNormKey = inverted;
    }

    @SuppressWarnings("unchecked")
    private PojoComparator(PojoComparator<T> toClone) {
        this.keyFields = toClone.keyFields;
        this.comparators = new TypeComparator[toClone.comparators.length];

        for (int i = 0; i < toClone.comparators.length; i++) {
            this.comparators[i] = toClone.comparators[i].duplicate();
        }

        this.normalizedKeyLengths = toClone.normalizedKeyLengths;
        this.numLeadingNormalizableKeys = toClone.numLeadingNormalizableKeys;
        this.normalizableKeyPrefixLen = toClone.normalizableKeyPrefixLen;
        this.invertNormKey = toClone.invertNormKey;

        this.type = toClone.type;

        try {
            this.serializer =
                    (TypeSerializer<T>)
                            InstantiationUtil.deserializeObject(
                                    InstantiationUtil.serializeObject(toClone.serializer),
                                    Thread.currentThread().getContextClassLoader());
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Cannot copy serializer", e);
        }
    }

    private void writeObject(ObjectOutputStream out) throws IOException, ClassNotFoundException {
        out.defaultWriteObject();
        out.writeInt(keyFields.length);
        for (Field field : keyFields) {
            FieldSerializer.serializeField(field, out);
        }
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        int numKeyFields = in.readInt();
        keyFields = new Field[numKeyFields];
        for (int i = 0; i < numKeyFields; i++) {
            keyFields[i] = FieldSerializer.deserializeField(in);
        }
    }

    public Field[] getKeyFields() {
        return this.keyFields;
    }

    @SuppressWarnings({"rawtypes", "unchecked"})
    @Override
    public void getFlatComparator(List<TypeComparator> flatComparators) {
        for (int i = 0; i < comparators.length; i++) {
            if (comparators[i] instanceof CompositeTypeComparator) {
                ((CompositeTypeComparator) comparators[i]).getFlatComparator(flatComparators);
            } else {
                flatComparators.add(comparators[i]);
            }
        }
    }

    /** This method is handling the IllegalAccess exceptions of Field.get() */
    public final Object accessField(Field field, Object object) {
        try {
            object = field.get(object);
        } catch (NullPointerException npex) {
            throw new NullKeyFieldException(
                    "Unable to access field " + field + " on object " + object);
        } catch (IllegalAccessException iaex) {
            throw new RuntimeException(
                    "This should not happen since we call setAccesssible(true) in the ctor."
                            + " fields: "
                            + field
                            + " obj: "
                            + object);
        }
        return object;
    }

    @Override
    public int hash(T value) {
        int i = 0;
        int code = 0;
        for (; i < this.keyFields.length; i++) {
            code *= TupleComparatorBase.HASH_SALT[i & 0x1F];
            try {
                code += this.comparators[i].hash(accessField(keyFields[i], value));
            } catch (NullPointerException npe) {
                throw new RuntimeException(
                        "A NullPointerException occurred while accessing a key field in a POJO. "
                                + "Most likely, the value grouped/joined on is null. Field name: "
                                + keyFields[i].getName(),
                        npe);
            }
        }
        return code;
    }

    @Override
    public void setReference(T toCompare) {
        int i = 0;
        for (; i < this.keyFields.length; i++) {
            this.comparators[i].setReference(accessField(keyFields[i], toCompare));
        }
    }

    @Override
    public boolean equalToReference(T candidate) {
        int i = 0;
        for (; i < this.keyFields.length; i++) {
            if (!this.comparators[i].equalToReference(accessField(keyFields[i], candidate))) {
                return false;
            }
        }
        return true;
    }

    @Override
    public int compareToReference(TypeComparator<T> referencedComparator) {
        PojoComparator<T> other = (PojoComparator<T>) referencedComparator;

        int i = 0;
        try {
            for (; i < this.keyFields.length; i++) {
                int cmp = this.comparators[i].compareToReference(other.comparators[i]);
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        } catch (NullPointerException npex) {
            throw new NullKeyFieldException(this.keyFields[i].toString());
        }
    }

    @Override
    public int compare(T first, T second) {
        int i = 0;
        for (; i < keyFields.length; i++) {
            int cmp =
                    comparators[i].compare(
                            accessField(keyFields[i], first), accessField(keyFields[i], second));
            if (cmp != 0) {
                return cmp;
            }
        }

        return 0;
    }

    @Override
    public int compareSerialized(DataInputView firstSource, DataInputView secondSource)
            throws IOException {
        T first = this.serializer.createInstance();
        T second = this.serializer.createInstance();

        first = this.serializer.deserialize(first, firstSource);
        second = this.serializer.deserialize(second, secondSource);

        return this.compare(first, second);
    }

    @Override
    public boolean supportsNormalizedKey() {
        return this.numLeadingNormalizableKeys > 0;
    }

    @Override
    public int getNormalizeKeyLen() {
        return this.normalizableKeyPrefixLen;
    }

    @Override
    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
        return this.numLeadingNormalizableKeys < this.keyFields.length
                || this.normalizableKeyPrefixLen == Integer.MAX_VALUE
                || this.normalizableKeyPrefixLen > keyBytes;
    }

    @Override
    public void putNormalizedKey(T value, MemorySegment target, int offset, int numBytes) {
        int i = 0;
        for (; i < this.numLeadingNormalizableKeys && numBytes > 0; i++) {
            int len = this.normalizedKeyLengths[i];
            len = numBytes >= len ? len : numBytes;
            this.comparators[i].putNormalizedKey(
                    accessField(keyFields[i], value), target, offset, len);
            numBytes -= len;
            offset += len;
        }
    }

    @Override
    public boolean invertNormalizedKey() {
        return this.invertNormKey;
    }

    @Override
    public boolean supportsSerializationWithKeyNormalization() {
        return false;
    }

    @Override
    public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public PojoComparator<T> duplicate() {
        return new PojoComparator<T>(this);
    }

    @Override
    public int extractKeys(Object record, Object[] target, int index) {
        int localIndex = index;
        for (int i = 0; i < comparators.length; i++) {
            localIndex +=
                    comparators[i].extractKeys(
                            accessField(keyFields[i], record), target, localIndex);
        }
        return localIndex - index;
    }

    // --------------------------------------------------------------------------------------------
}
